{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cell and Tissue Search Tutorial\n",
    "KRONOS embeddings can be used to search for and retrieve cell or tissue samples with similar phenotypic or spatial patterns.\n",
    "\n",
    "## Prerequisites\n",
    "\n",
    "To follow this tutorial, ensure you have the following data prepared:\n",
    "\n",
    "1. **Query cell/patch embeddings**: A folder of query cell/patch embeddings, each stored as a NumPy array in its own `.npy` file.\n",
    "2. **Support cell/patch embeddings**: A folder of support cell/patch embeddings in the same format as the query folder.\n",
    "\n",
    "\n",
    "**Note**: Refer to the **[2 - Cell-phenotyping.ipynb](https://github.com/mahmoodlab/KRONOS/blob/main/tutorials/2%20-%20Cell-phenotyping.ipynb)** and **[3 - Patch-phenotyping.ipynb](https://github.com/mahmoodlab/KRONOS/blob/main/tutorials/3%20-%20Patch-phenotyping.ipynb)** tutorials for cell and patch embedding extraction."
   ]
  },
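  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The folder layout described in the prerequisites can be sketched with synthetic data, which is handy for a dry run before real embeddings are available. This is only a sketch under assumptions: `make_dummy_embeddings`, the temporary folder names, and the embedding dimension of 256 are hypothetical and not part of KRONOS. Real embeddings should come from the tutorials referenced above. Pointing `query_folder` and `support_folder` in the configuration at these demo folders should be enough to exercise the rest of the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def make_dummy_embeddings(folder, prefix, n_samples, dim=256, seed=0):\n",
    "    # Hypothetical helper: write n_samples random 1-D embeddings, one .npy file each\n",
    "    os.makedirs(folder, exist_ok=True)\n",
    "    rng = np.random.default_rng(seed)\n",
    "    for i in range(n_samples):\n",
    "        embedding = rng.standard_normal(dim).astype(np.float32)\n",
    "        np.save(os.path.join(folder, f'{prefix}_{i}.npy'), embedding)\n",
    "\n",
    "\n",
    "# Assumed demo layout under a temporary directory (not the real data location)\n",
    "demo_root = tempfile.mkdtemp()\n",
    "demo_query_folder = os.path.join(demo_root, 'query')\n",
    "demo_support_folder = os.path.join(demo_root, 'support')\n",
    "make_dummy_embeddings(demo_query_folder, 'query_tissue', n_samples=3, seed=0)\n",
    "make_dummy_embeddings(demo_support_folder, 'support_tissue', n_samples=20, seed=1)\n",
    "print(sorted(os.listdir(demo_query_folder)))"
   ]
  },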
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Import Required Packages\n",
    "\n",
    "We begin by importing the necessary libraries and modules for the workflow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Experiment Configuration"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Configure retrieval experiment settings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration dictionary containing all parameters for the pipeline\n",
    "config = {\n",
    "    \"query_folder\": \"/data/query\",  # Replace with your actual query folder\n",
    "    \"support_folder\": \"/data/support\",  # Replace with your actual support folder\n",
    "    \"results_path\": \"retrieval_results.csv\",  # Replace with your actual results path\n",
    "    \"device\": \"cuda:0\",  # Replace with the device you want to use, e.g. \"cpu\"\n",
    "    # Retrieval-related parameters\n",
    "    \"topk\": 5,  # Number of support samples to retrieve for each query\n",
    "    \"similarity\": \"l2\",  # \"l2\" (negative Euclidean distance, the default) or \"cosine\"\n",
    "    \"centering\": True  # Whether to subtract the support-set mean from all features before retrieval\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Define Retrieval Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define a function that takes the configuration values above as input, performs the retrieval, and saves the top-k retrieved support samples for each query sample to a CSV file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def retrieval(query_folder, support_folder, results_path, device='cuda:0', topk=5, similarity='l2', centering=True):\n",
    "    # load query and support embeddings from their respective folders\n",
    "    query_filenames = [file for file in os.listdir(query_folder) if file.endswith('.npy')]\n",
    "    queries = np.array([np.load(os.path.join(query_folder, file)) for file in query_filenames])\n",
    "    support_filenames = [file for file in os.listdir(support_folder) if file.endswith('.npy')]\n",
    "    keys = np.array([np.load(os.path.join(support_folder, filename)) for filename in support_filenames])\n",
    "    assert similarity in ['cosine', 'l2']\n",
    "    \n",
    "    # move both queries and supports to device\n",
    "    queries = torch.from_numpy(queries).to(device)\n",
    "    keys = torch.from_numpy(keys).to(device)\n",
    "\n",
    "    # optionally center on the support-set mean, then L2-normalize (always for cosine, and whenever centering is enabled)\n",
    "    if centering:\n",
    "        means = keys.mean(dim=0, keepdim=True)\n",
    "        keys = keys - means\n",
    "        queries = queries - means\n",
    "    if similarity == 'cosine' or centering:\n",
    "        queries = F.normalize(queries.float(), dim=1)\n",
    "        keys = F.normalize(keys.float(), dim=1)\n",
    "    \n",
    "    # calculate similarities\n",
    "    if similarity == 'cosine':\n",
    "        sim_scores = torch.matmul(queries, keys.T)\n",
    "    elif similarity == 'l2':\n",
    "        sim_scores = -torch.cdist(queries, keys, p=2) # take negative to make it a similarity\n",
    "    else:\n",
    "        raise ValueError(f'similarity {similarity} not supported!')\n",
    "\n",
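    "    # obtain the indices of the topk most similar support samples for each query\n",
    "    _, topk_ids = torch.topk(sim_scores, topk, dim=1)\n",
    "    # move indices to CPU before using them to index the NumPy array of file names\n",
    "    topk_filenames = np.array(support_filenames)[topk_ids.cpu().numpy()]\n",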
    "\n",
    "    # make the result csv\n",
    "    topk_df = pd.DataFrame(topk_filenames, columns=[f'top{k+1}' for k in range(topk_filenames.shape[1])])\n",
    "    topk_df.index = query_filenames\n",
    "    topk_df.to_csv(results_path)"
   ]
  },
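  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One consequence of the preprocessing in `retrieval` is worth spelling out: features are L2-normalized whenever `centering` is enabled, even with `similarity='l2'`. For unit vectors, $\\|q - k\\|^2 = 2 - 2\\,q \\cdot k$, so ranking by negative L2 distance and ranking by cosine similarity should give the same order in that case. The cell below is a minimal sketch that checks this on random CPU tensors. The tensor shapes, the seed, and the `demo_*` names are arbitrary assumptions chosen for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Random unit-norm stand-ins for query and support embeddings (assumed shapes)\n",
    "demo_queries = F.normalize(torch.randn(4, 16), dim=1)\n",
    "demo_keys = F.normalize(torch.randn(10, 16), dim=1)\n",
    "\n",
    "cosine_scores = demo_queries @ demo_keys.T\n",
    "l2_scores = -torch.cdist(demo_queries, demo_keys, p=2)\n",
    "\n",
    "# For unit vectors, squared L2 distance equals 2 - 2 * cosine similarity\n",
    "identity_holds = torch.allclose(l2_scores ** 2, 2 - 2 * cosine_scores, atol=1e-4)\n",
    "# Both scores should therefore pick the same nearest support sample\n",
    "same_top1 = torch.equal(cosine_scores.argmax(dim=1), l2_scores.argmax(dim=1))\n",
    "print(identity_holds, same_top1)"
   ]
  },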
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Perform Actual Retrieval"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we run the retrieval with the configuration defined above. The resulting CSV file, saved at the `results_path` set in `config`, lists the file names of the top-k retrieved support samples for each query. It will look similar to the following:\n",
    "\n",
    "||top1|top2|top3|top4|top5|\n",
    "|--|--|--|--|--|--|\n",
    "|query_tissue_1.npy|support_tissue_10.npy|support_tissue_2.npy|support_tissue_5.npy|support_tissue_26.npy|support_tissue_23.npy|\n",
    "|...|\n",
    "|query_tissue_n.npy|support_tissue_50.npy|support_tissue_36.npy|support_tissue_25.npy|support_tissue_46.npy|support_tissue_33.npy|"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retrieval(**config)"
   ]
  }
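  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the CSV has been written, it can be loaded back for inspection. The cell below is a small sketch, assuming the file was saved by `retrieval` with the query file names as the index column. `load_retrieval_results` and `demo_results_path` are hypothetical names introduced here, and the path is assumed to match `config['results_path']`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def load_retrieval_results(results_path):\n",
    "    # The first CSV column holds the query file names written as the DataFrame index\n",
    "    return pd.read_csv(results_path, index_col=0)\n",
    "\n",
    "\n",
    "demo_results_path = 'retrieval_results.csv'  # assumed to match config['results_path']\n",
    "if os.path.exists(demo_results_path):\n",
    "    results_df = load_retrieval_results(demo_results_path)\n",
    "    # Each row is one query; columns top1..topk hold support file names\n",
    "    print(results_df.head())"
   ]
  }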
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kronos_v1",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
